/**
 * Copyright (C) 2024 AIDC-AI
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.langengine.deepsearch.agent;

import com.alibaba.langengine.core.chatmodel.BaseChatModel;
import com.alibaba.langengine.core.indexes.Document;
import com.alibaba.langengine.core.messages.BaseMessage;
import com.alibaba.langengine.core.messages.HumanMessage;
import com.alibaba.langengine.core.vectorstore.VectorStore;
import com.alibaba.langengine.deepsearch.utils.OutputParserUtils;
import com.alibaba.langengine.deepsearch.utils.VectorStoreUtils;
import com.alibaba.langengine.deepsearch.vectorstore.*;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * RAG agent that answers a main query through a chain of LLM-generated follow-up questions.
 *
 * <p>Each iteration of {@link #retrieve(String, Map)} performs three LLM calls:
 * <ol>
 *   <li>generate a simple follow-up question from the main query and the intermediate
 *       question/answer pairs collected so far;</li>
 *   <li>search the vector store with that follow-up question and answer it from the hits;</li>
 *   <li>ask which of the hits actually support the intermediate answer.</li>
 * </ol>
 * {@link #query(String, Map)} then combines the supporting documents and the intermediate
 * question/answer pairs into a final answer.
 */
@Slf4j
public class ChainOfRAGAgent extends RAGAgent {

    /** Number of follow-up iterations used when {@code kwargs} has no {@code "maxIter"} entry. */
    private static final int DEFAULT_MAX_ITER = 3;

    /** Number of documents requested from the vector store for each follow-up question. */
    private static final int RETRIEVAL_TOP_K = 5;

    /** Marker the intermediate-answer prompt tells the LLM to emit when nothing useful was found. */
    private static final String NO_RELEVANT_INFO_MARKER = "No relevant information found";

    /** Key under which {@link #retrieve(String, Map)} publishes the intermediate Q/A pairs. */
    private static final String INTERMEDIATE_CONTEXT_KEY = "intermediate_context";

    private final BaseChatModel llm;
    private final VectorStore vectorDb;

    private static final String FOLLOWUP_QUERY_PROMPT =
            "You are using a search tool to answer the main query by iteratively searching the database. Given the following intermediate queries and answers, generate a new simple follow-up question that can help answer the main query. You may rephrase or decompose the main query when previous answers are not helpful. Ask simple follow-up questions only as the search tool may not understand complex questions.\n\n" +
                    "## Previous intermediate queries and answers\n%s\n\n" +
                    "## Main query to answer\n%s\n\n" +
                    "Respond with a simple follow-up question that will help answer the main query, do not explain yourself or output anything else.";

    private static final String INTERMEDIATE_ANSWER_PROMPT =
            "Given the following documents, generate an appropriate answer for the query. DO NOT hallucinate any information, only use the provided documents to generate the answer. Respond “No relevant information found” if the documents do not contain useful information.\n\n" +
                    "## Documents\n%s\n\n" +
                    "## Query\n%s\n\n" +
                    "Respond with a concise answer only, do not explain yourself or output anything else.";

    private static final String FINAL_ANSWER_PROMPT =
            "Given the following intermediate queries and answers, generate a final answer for the main query by combining relevant information. Note that intermediate answers are generated by an LLM and may not always be accurate.\n\n" +
                    "## Documents\n%s\n\n" +
                    "## Intermediate queries and answers\n%s\n\n" +
                    "## Main query\n%s\n\n" +
                    "Respond with an appropriate answer only, do not explain yourself or output anything else.";

    private static final String GET_SUPPORTED_DOCS_PROMPT =
            "Given the following documents, select the ones that are support the Q-A pair.\n\n" +
                    "## Documents\n%s\n\n" +
                    "## Q-A Pair\n### Question\n%s\n### Answer\n%s\n\n" +
                    "Respond with a python list of indices of the selected documents.";

    /**
     * Creates an agent bound to a chat model and a vector store.
     *
     * @param llm      chat model used for every prompt issued by this agent
     * @param vectorDb vector store searched for each follow-up question
     */
    public ChainOfRAGAgent(BaseChatModel llm, VectorStore vectorDb) {
        this.llm = llm;
        this.vectorDb = vectorDb;
    }

    /**
     * Sends a prompt to the chat model as a single human message.
     *
     * @param prompt fully formatted prompt text
     * @return the model's response message
     */
    private BaseMessage runPrompt(String prompt) {
        HumanMessage humanMessage = new HumanMessage();
        humanMessage.setContent(prompt);
        List<BaseMessage> messages = new ArrayList<>();
        messages.add(humanMessage);
        return llm.run(messages);
    }

    /**
     * Reads the total token count of a response, treating a missing count as zero so that
     * token accounting never fails on unboxing.
     *
     * @param response chat model response
     * @return the reported total tokens, or {@code 0} when none is reported
     */
    private static long tokensOf(BaseMessage response) {
        Long total = response.getTotalTokens();
        return total == null ? 0L : total;
    }

    /**
     * Asks the LLM for the next follow-up question.
     *
     * @param query               the main query
     * @param intermediateContext intermediate question/answer pairs gathered so far
     * @return the follow-up question together with the tokens spent producing it
     */
    private SubQueryResult reflectGetSubquery(String query, List<String> intermediateContext) {
        String prompt = String.format(FOLLOWUP_QUERY_PROMPT, String.join("\n", intermediateContext), query);
        BaseMessage chatResponse = runPrompt(prompt);
        return new SubQueryResult(chatResponse.getContent(), tokensOf(chatResponse));
    }

    /**
     * Searches the vector store for a follow-up question and lets the LLM answer it from the hits.
     *
     * @param query  the follow-up question
     * @param kwargs caller options; currently not consulted by this step
     * @return the intermediate answer, the documents shown to the LLM, and the tokens spent
     */
    private AnswerResult retrieveAndAnswer(String query, Map<String, Object> kwargs) {
        List<Document> hits = vectorDb.similaritySearch(query, RETRIEVAL_TOP_K);
        // Work on a private copy so the helper never operates on a list owned by the vector store.
        List<Document> retrieved = hits == null ? new ArrayList<>() : new ArrayList<>(hits);
        retrieved = VectorStoreUtils.deduplicateResults(retrieved);

        String prompt = String.format(INTERMEDIATE_ANSWER_PROMPT, formatRetrievedResults(retrieved), query);
        BaseMessage chatResponse = runPrompt(prompt);
        return new AnswerResult(chatResponse.getContent(), retrieved, tokensOf(chatResponse));
    }

    /**
     * Asks the LLM which of the retrieved documents support an intermediate question/answer pair.
     *
     * <p>The LLM call is skipped when the answer is missing, carries the "no relevant information"
     * marker, or when there are no documents to choose from. Indices returned by the LLM that are
     * not integers or fall outside the document list are ignored instead of aborting retrieval.
     *
     * @param retrievedResults   documents that were shown to the LLM for the intermediate answer
     * @param query              the follow-up question
     * @param intermediateAnswer the LLM's answer to that question
     * @return the supporting documents and the tokens spent selecting them
     */
    private SupportedDocsResult getSupportedDocs(List<Document> retrievedResults, String query, String intermediateAnswer) {
        List<Document> supported = new ArrayList<>();
        boolean nothingToSelect = intermediateAnswer == null
                || intermediateAnswer.contains(NO_RELEVANT_INFO_MARKER)
                || retrievedResults == null
                || retrievedResults.isEmpty();
        if (nothingToSelect) {
            return new SupportedDocsResult(supported, 0L);
        }

        String prompt = String.format(GET_SUPPORTED_DOCS_PROMPT, formatRetrievedResults(retrievedResults), query, intermediateAnswer);
        BaseMessage chatResponse = runPrompt(prompt);

        List<String> indices = OutputParserUtils.literalEval(chatResponse.getContent());
        if (indices != null) {
            for (String rawIndex : indices) {
                Document document = resolveDocument(retrievedResults, rawIndex);
                if (document != null) {
                    supported.add(document);
                }
            }
        }
        return new SupportedDocsResult(supported, tokensOf(chatResponse));
    }

    /**
     * Maps one index produced by the LLM to a document, tolerating malformed output.
     *
     * @param candidates documents the index refers to
     * @param rawIndex   index text as returned by the output parser
     * @return the referenced document, or {@code null} when the index is unusable
     */
    private static Document resolveDocument(List<Document> candidates, String rawIndex) {
        if (rawIndex == null) {
            return null;
        }
        int index;
        try {
            index = Integer.parseInt(rawIndex.trim());
        } catch (NumberFormatException e) {
            log.warn("Ignoring non-numeric document index returned by the LLM: {}", rawIndex);
            return null;
        }
        if (index < 0 || index >= candidates.size()) {
            log.warn("Ignoring out-of-range document index {} (document count: {})", index, candidates.size());
            return null;
        }
        return candidates.get(index);
    }

    /**
     * Determines the iteration count from the caller options.
     *
     * @param kwargs caller options, may be {@code null}
     * @return the {@code "maxIter"} value, or {@link #DEFAULT_MAX_ITER} when absent
     * @throws NumberFormatException if {@code "maxIter"} is neither a number nor an integer string
     */
    private static int resolveMaxIter(Map<String, Object> kwargs) {
        Object configured = kwargs == null ? null : kwargs.get("maxIter");
        if (configured == null) {
            return DEFAULT_MAX_ITER;
        }
        if (configured instanceof Number) {
            return ((Number) configured).intValue();
        }
        return Integer.parseInt(configured.toString().trim());
    }

    @Override
    public String getDescription() {
        return  "This agent can decompose complex queries and gradually find the fact information of sub-queries. It is very suitable for handling concrete factual queries and multi-hop questions.";
    }

    /**
     * Runs the iterative follow-up loop and collects the documents that support the
     * intermediate answers.
     *
     * @param query  the main query
     * @param kwargs caller options; {@code "maxIter"} (number or integer string) overrides the
     *               default iteration count; may be {@code null}
     * @return the supporting documents, the total tokens spent, and the intermediate
     *         question/answer pairs under the {@code "intermediate_context"} additional-info key
     */
    @Override
    public RetrievalResultData retrieve(String query, Map<String, Object> kwargs) {
        int maxIter = resolveMaxIter(kwargs);
        List<String> intermediateContexts = new ArrayList<>();
        List<Document> allRetrievedResults = new ArrayList<>();
        long tokenUsage = 0L;
        for (int iter = 0; iter < maxIter; iter++) {
            log.info(">> Iteration: {}\n", iter + 1);
            SubQueryResult subQueryResult = reflectGetSubquery(query, intermediateContexts);
            String subQuery = subQueryResult.getQuery();
            AnswerResult answerResult = retrieveAndAnswer(subQuery, kwargs);
            SupportedDocsResult supportedDocsResult = getSupportedDocs(answerResult.getRetrievedResults(), subQuery, answerResult.getAnswer());

            allRetrievedResults.addAll(supportedDocsResult.getSupportedRetrievedResults());
            int step = intermediateContexts.size() + 1;
            intermediateContexts.add("Intermediate query" + step + ": " + subQuery + "\nIntermediate answer" + step + ": " + answerResult.getAnswer());
            tokenUsage += subQueryResult.getTokenCount() + answerResult.getTokenUsage() + supportedDocsResult.getTokenUsage();
        }
        // The same document can support several intermediate answers.
        allRetrievedResults = VectorStoreUtils.deduplicateResults(allRetrievedResults);
        RetrievalResultData data = new RetrievalResultData(allRetrievedResults, tokenUsage);
        data.setAdditionalInfo(Collections.singletonMap(INTERMEDIATE_CONTEXT_KEY, intermediateContexts));
        return data;
    }

    /**
     * Retrieves supporting documents via {@link #retrieve(String, Map)} and asks the LLM for a
     * final answer based on them and on the intermediate question/answer pairs.
     *
     * @param query  the main query
     * @param kwargs caller options forwarded to {@link #retrieve(String, Map)}
     * @return the final answer, the supporting documents, and the total tokens spent
     */
    @Override
    public RetrievalResultData query(String query, Map<String, Object> kwargs) {
        RetrievalResultData retrievalResultData = retrieve(query, kwargs);
        List<Document> allRetrievedResults = retrievalResultData.getDocuments();
        // retrieve() stores a List<String> under this key, so the cast is safe for this class.
        @SuppressWarnings("unchecked")
        List<String> intermediateContext = (List<String>) retrievalResultData.getAdditionalInfo().get(INTERMEDIATE_CONTEXT_KEY);

        log.info(" Summarize answer from all {} retrieved chunks... \n", allRetrievedResults.size());
        String prompt = String.format(FINAL_ANSWER_PROMPT, formatRetrievedResults(allRetrievedResults), String.join("\n", intermediateContext), query);
        BaseMessage chatResponse = runPrompt(prompt);

        log.info("\n==== FINAL ANSWER====\n");
        log.info("{}", chatResponse.getContent());

        return new RetrievalResultData(chatResponse.getContent(), allRetrievedResults, retrievalResultData.getConsumeTokens() + tokensOf(chatResponse));
    }

    /**
     * Renders documents as numbered pseudo-XML sections for inclusion in a prompt. The number in
     * each tag is the document's position in the list, which is what the supported-docs prompt
     * expects the LLM to echo back.
     *
     * @param retrievedResults documents to render
     * @return the concatenated sections, or an empty string for an empty list
     */
    private String formatRetrievedResults(List<Document> retrievedResults) {
        StringBuilder formattedDocuments = new StringBuilder();
        for (int i = 0; i < retrievedResults.size(); i++) {
            String text = retrievedResults.get(i).getPageContent();
            formattedDocuments.append("<Document ").append(i).append(">\n")
                    .append(text)
                    .append("\n<\\Document ").append(i).append(">");
        }
        return formattedDocuments.toString();
    }
}
